# -*- coding: utf-8 -*-
"""
Created on Thu Jul  8 17:54:31 2021

@author: SonHyungmok
"""


import numpy as np
import matplotlib.pyplot as plt
import os
from os import listdir
from os.path import isfile, join
from scipy.optimize import curve_fit

from scipy.integrate import odeint

class lc():
    """Loss-curve analysis: average repeated shots per scan value and fit
    one-/two-body loss models to the averaged curves."""

    def __init__(self):
        # Placeholder attribute; no method in this class reads it.
        self.nothing = 1.

    # ------------------------------------------------------------------
    # data reduction
    # ------------------------------------------------------------------
    @staticmethod
    def _group_stats(x, y, flgerr='stderr'):
        """Group ``y`` by identical ``x`` values and summarise each group.

        Returns an array of shape (3, k): the k distinct ``x`` values in
        ascending order, the mean of ``y`` in each group, and its error
        (population standard deviation if ``flgerr == 'stdev'``, otherwise
        the standard error of the mean, std/sqrt(count)).
        """
        x = np.atleast_1d(np.asarray(x, dtype=float))
        y = np.atleast_1d(np.asarray(y, dtype=float))
        if x.size == 0:
            return np.empty((3, 0))

        order = np.argsort(x, kind='stable')
        x_sorted = x[order]
        y_sorted = y[order]
        # First index of every distinct value in the sorted array; splitting
        # there yields one contiguous chunk of y per distinct x.
        var, starts = np.unique(x_sorted, return_index=True)
        groups = np.split(y_sorted, starts[1:])

        avg = np.array([g.mean() for g in groups])
        std = np.array([g.std() for g in groups])
        if flgerr == 'stdev':
            err = std
        else:
            ## this is standard error
            err = std / np.sqrt([len(g) for g in groups])
        return np.array([var, avg, err])

    def processData_internal(self, fileName, varindex = 1, dataindex = 10, kind = 'mol', flgerr = 'stderr'):
        """Load one scan file and average repeated shots at each scan value.

        Parameters
        ----------
        fileName : str
            Path to a tab-delimited text file.
        varindex : int
            Column holding the scanned variable (``calc2`` treats it as time).
        dataindex : int
            Column holding the raw signal. The signal is divided by a
            calibration factor: polarization correction x number of pancakes,
            per the variable names below.
        kind : str
            'mol' selects the molecule calibration; any other value selects
            the Na calibration.
        flgerr : str
            'stdev' -> per-point standard deviation; anything else ->
            standard error of the mean.

        Returns
        -------
        numpy.ndarray, shape (3, k)
            Rows are: distinct scan values (ascending), mean calibrated
            signal, and its error.
        """
        if kind == 'mol':
            polarization_correct_mol = 0.9
            num_pancakes_mol = 405. * np.sqrt(8/np.pi)/(1596/2 * 1e-3)
            correct = polarization_correct_mol * num_pancakes_mol
        else:
            num_pancakes_na = 662. *1e-6 * np.sqrt(8/np.pi)/(0.5 * 1596e-9)
            polarization_correct = 0.65
            correct = polarization_correct * num_pancakes_na

        # ndmin=2 keeps a single-row file 2-D so data[0]/data[1] stay arrays.
        data = np.loadtxt(fileName, delimiter='\t', unpack=True,
                          usecols=(varindex, dataindex), ndmin=2)
        x = data[0]
        y = data[1] / correct
        return self._group_stats(x, y, flgerr)

    # ------------------------------------------------------------------
    # loss models
    # ------------------------------------------------------------------
    @staticmethod
    def _twobody(t, initnum, gamma2):
        """Pure two-body decay N0/(gamma2*t + 1), with N(0) = initnum."""
        return initnum/(gamma2 * t + 1)

    @staticmethod
    def _loss_rate(n, t, gamma1, gamma2):
        """Rate equation dn/dt = -gamma1*n - gamma2*n**2 (odeint signature)."""
        return -gamma1 * n - gamma2 * n**2

    @classmethod
    def _integrate_loss(cls, t, initnum, gamma1, gamma2):
        """Integrate the one- plus two-body rate equation from n(0) = initnum.

        odeint treats the first time point as the initial time, so 0 is
        prepended to ``t``. That makes ``initnum`` the value at t = 0, as in
        ``_twobody`` and on the 0-based plotting grid. Assumes ``t`` is
        non-negative and ascending (odeint needs a monotonic sequence;
        repeated values are allowed).
        """
        t_eval = np.concatenate(([0.0], np.atleast_1d(np.asarray(t, dtype=float))))
        sol = odeint(cls._loss_rate, initnum, t_eval, args=(gamma1, gamma2))
        return sol[1:, 0]

    def _fit_na(self, path, mol_num):
        """Fit the Na signal stored in ``path`` (data column 8, Na calibration).

        ``mol_num`` is the averaged molecule curve from the same file. Its
        first point fixes the initial Na/molecule ratio and difference.
        Returns [x, nanum, nanum_err, tt_fit, fit_curve] for plotting.
        """
        out = self.processData_internal(path, dataindex=8, kind = 'na')
        x, nanum, nanum_err = out[0], out[1], out[2]

        init_num_ratio = nanum[0]/mol_num[0]
        init_num_diff = nanum[0] - mol_num[0]

        def func_pa_na_fix(t, gamma):
            # Closed form of dN/dt = -gamma*N*(N - D), where D is the initial
            # Na-minus-molecule difference; N(0) equals nanum[0].
            C = init_num_ratio
            D = init_num_diff
            return -D/(C**-1 * np.exp(-D*gamma*t) - 1)

        # Fit against the Na file's own scan values.
        popt_na, _ = curve_fit(func_pa_na_fix, x, nanum, p0=[1e-5])
        tt_temp = np.linspace(0, np.max(x), 200)
        return [x, nanum, nanum_err, tt_temp, func_pa_na_fix(tt_temp, *popt_na)]

    # ------------------------------------------------------------------
    # main analysis
    # ------------------------------------------------------------------
    def calc2(self, ww, ref_file='978G_molonly', file_list=('976G', '999G'),
              na_file='980G',
              labels=('977.7G (NaLi only)', '976.2G', '998.9G')):
        """Fit molecule loss curves for a set of ``<name>.txt`` files in ``ww``.

        1. The reference file (``ref_file``) is fitted with the pure
           two-body model N0/(gamma2*t + 1).
        2. The two-body coefficient gamma2/N0 from that fit is held fixed
           while every file in ``file_list`` is fitted with the one- plus
           two-body rate equation (free parameters: N0, gamma1).
        3. If ``na_file`` appears in ``file_list``, that file's Na signal is
           also fitted (see ``_fit_na``).

        Fitted rates are printed multiplied by 1000 and labelled 1/sec.
        This presumably assumes the scan variable is in ms; confirm.

        Returns
        -------
        list
            [tt, fit_list, time_list, data_list, err_list, label_list, na_packed]
            ``tt`` is a 2000-point grid from 0 to the last reference time.
            The list entries start with the reference file, followed by
            ``file_list`` in order. ``na_packed`` is ``[]`` when ``na_file``
            is not in ``file_list``.
        """
        working_dir = ww
        print(working_dir)

        ref = self.processData_internal(join(working_dir, ref_file + '.txt'))
        time, num, err = ref[0], ref[1], ref[2]

        ################
        ## two-body fit
        popt2, _ = curve_fit(self._twobody, time, num, p0=[num[0], 1e-4])
        print("two body loss rate = {} (1/sec)".format(popt2[-1]*1000.))
        twobody_coeff = popt2[-1]/popt2[0] ## this is beta constant divided by volume

        def onetwobody(t, initnum, gamma1):
            # Two-body coefficient is fixed from the reference fit above.
            return self._integrate_loss(t, initnum, gamma1, twobody_coeff)

        tt = np.linspace(0, np.max(time), 2000)

        time_list = [time]
        data_list = [num]
        err_list = [err]
        fit_list = [self._twobody(tt, *popt2)]
        na_packed = []
        for fileName in file_list:
            path = join(working_dir, fileName + '.txt')
            x, num, err = self.processData_internal(path)

            ## curve_fit to the full model with odeint
            popt, _ = curve_fit(onetwobody, x, num, p0=[num[0], 1e-1])

            time_list.append(x)
            data_list.append(num)
            err_list.append(err)
            fit_list.append(onetwobody(tt, *popt))
            print("for " + str(fileName) + ", loss rate = {} (1/sec)".format(popt[-1]*1000))

            # Assigned only for the Na file, so later files do not reset it.
            if fileName == na_file:
                na_packed = self._fit_na(path, num)

        return [tt, fit_list, time_list, data_list, err_list, list(labels), na_packed]
#        
if __name__ == "__main__":
    working_dir = os.getcwd()
    [tt, fit_list, time_list, data_list, err_list, label_list, na_packed] = lc().calc2(working_dir)

    # Molecule loss curves: averaged data with error bars plus fitted model,
    # sharing one colour per data set.
    plt.figure()
    curves = zip(time_list, data_list, err_list, fit_list, label_list)
    for k, (t, n, e, fit, label) in enumerate(curves):
        color = 'C{}'.format(k)
        plt.errorbar(t, n, yerr=e, ls='', marker='o', color=color, label=label)
        plt.plot(tt, fit, c=color)
    plt.legend()

    # calc2 returns an empty na_packed unless the Na file was analysed, so
    # plot the Na curve only when it exists.
    if na_packed:
        plt.figure()
        plt.scatter(na_packed[0], na_packed[1])
        plt.errorbar(na_packed[0], na_packed[1], yerr = na_packed[2], ls='')
        plt.plot(na_packed[-2], na_packed[-1])
        plt.ylim([0, 150])

    plt.show()